package sess

import (
	"encoding/binary"
	"errors"
	"fmt"

	"gitee.com/night-tc/gobige/logger"
	"gitee.com/night-tc/gobige/msgdef"
	"gitee.com/night-tc/gobige/pool"
	"github.com/golang/snappy"
)

/*
消息定义文件
消息格式


流式结构如下：
消息头| 消息体 | 消息头 | 消息体

消息头 分为2部分：3个byte表示消息长度，1个byte表示是否压缩

消息体 分为2部分：2个byte表示这消息号，后面长度表示数据

*/

// GetMsgID 获取消息ID
func GetMsgID(buf []byte) uint16 {
	if len(buf) < MsgIDSize {
		return 0
	}

	return uint16(buf[0]) | uint16(buf[1])<<8
}

//var encrypt_key = []byte{253, 1, 56, 52, 62, 176, 42, 138}
//var decrypt_key = []byte{41, 253, 1, 56, 52, 62, 176, 42}

var server_key = []byte{253, 1, 56, 52, 62, 176, 42, 138}
var client_key = []byte{41, 253, 1, 56, 52, 62, 176, 42}

// 设置服务器编码的密钥
func SetServer_key(bytes []byte) {
	server_key = bytes
}

// 获取服务器编码的密钥
func GetServer_key() []byte {
	return server_key
}

// 设置客户端编码的密钥
func SetClient_key(bytes []byte) {
	client_key = bytes
}

// 获取客户端编码的密钥
func GetClient_key() []byte {
	return client_key
}

// EncryptData 服务器加密算法
func EncryptData(buf []byte) []byte {
	buflen := len(buf)
	key := GetServer_key()
	keylen := len(key)

	for i := 0; i < buflen; i++ {
		n := byte(i%7 + 1)                       //移位长度(1-7)
		b := (buf[i] << n) | (buf[i] >> (8 - n)) // 向左循环移位

		buf[i] = b ^ key[i%keylen]
	}

	return buf
}

// DecryptData 服务器解密算法
func DecryptData(buf []byte) []byte {

	buflen := len(buf)
	key := GetClient_key()
	keylen := len(key)

	for i := 0; i < buflen; i++ {

		b := buf[i] ^ key[i%keylen]

		n := byte(i%7 + 1)                 //移位长度(1-7)
		buf[i] = (b >> n) | (b << (8 - n)) // 向右循环移位
	}
	return buf
}

// DecryptDataByClient 模拟客户端解密算法
func DecryptDataByClient(buf []byte) []byte {

	buflen := len(buf)
	key := GetServer_key()
	keylen := len(key)

	for i := 0; i < buflen; i++ {

		b := buf[i] ^ key[i%keylen]

		n := byte(i%7 + 1)                 //移位长度(1-7)
		buf[i] = (b >> n) | (b << (8 - n)) // 向右循环移位
	}
	return buf
}

// DecodeMsg 返回消息名称及反序列化后的消息对象
func DecodeMsgByMsgID(flag byte, buf []byte, msgID uint32) (msgContent msgdef.IMsg, err error) {
	var msgBody []byte
	if msgID == 0 {
		if len(buf) < MsgIDSize {
			logger.Error("数据格式错误, buf:", len(buf))
			return nil, errors.New("长度错误")
		}
		_, msgContent, err = msgdef.GetTypeMgr().GetMsgInfo(GetMsgID(buf))
		if err != nil {
			return nil, err
		}
		msgBody = buf[MsgIDSize:]
	} else {
		_, msgContent, err = msgdef.GetTypeMgr().GetMsgInfo(uint16(msgID))
		if err != nil {
			return nil, err
		}
		msgBody = buf
	}

	encryptFlag := flag & SessFlag_Encrypt
	if encryptFlag > 0 {
		msgBody = DecryptData(msgBody)
	}

	compressFlag := flag & SessFlag_Compress
	if compressFlag > 0 {
		msgBuf := pool.Get(MaxMsgBuffer) //make([]byte, MaxMsgBuffer)
		defer pool.Put(msgBuf)
		// modify by zgb 2020/09/30 这里原来在snappy里加了一个函数 用来检测解压出来的包大小是否超过最大值，其实不需要，因为在readARQMsgForward前面就已经检查过了
		msgBuf, err = snappy.Decode(msgBuf, msgBody)
		if err != nil {
			return nil, err
		}
		msgBody = msgBuf
	}
	if err = msgContent.Unmarshal(msgBody); err != nil {
		return nil, err
	}

	return msgContent, nil
}

// DecodeMsg 返回消息名称及反序列化后的消息对象
func DecodeMsg(flag byte, buf []byte) (msgdef.IMsg, error) {

	if len(buf) < MsgIDSize {
		logger.Error("数据格式错误, buf:", len(buf))
		return nil, errors.New("长度错误")
	}
	_, msgContent, err := msgdef.GetTypeMgr().GetMsgInfo(GetMsgID(buf))
	if err != nil {
		return nil, err
	}

	msgBody := buf[MsgIDSize:]

	encryptFlag := flag & SessFlag_Encrypt
	if encryptFlag > 0 {
		msgBody = DecryptData(msgBody)
	}

	compressFlag := flag & SessFlag_Compress
	if compressFlag > 0 {
		msgBuf := pool.Get(MaxMsgBuffer) //make([]byte, MaxMsgBuffer)
		defer pool.Put(msgBuf)
		// modify by zgb 2020/09/30 这里原来在snappy里加了一个函数 用来检测解压出来的包大小是否超过最大值，其实不需要，因为在readARQMsgForward前面就已经检查过了
		msgBuf, err = snappy.Decode(msgBuf, msgBody)
		if err != nil {
			return nil, err
		}
		msgBody = msgBuf
	}
	if err = msgContent.Unmarshal(msgBody); err != nil {
		return nil, err
	}

	return msgContent, nil
}

/*
EncodeMsg 将消息编码为字节流，可选择禁用压缩但不启用加密

参数：
  - msg: 要编码的消息实例，必须实现 IMsg 接口
  - buf: 用于编码的缓冲区，可复用或预分配内存。若为nil会自动分配
  - forceNoCompress: 为true时强制禁用压缩，即使消息默认启用压缩

返回值：
  - []byte: 编码后的字节数据，可能引用传入的buf或新分配的内存
  - error: 编码过程中遇到的错误，成功时为nil

说明：

	本函数通过调用 EncodeMsgWithEncrypt 实现，固定不启用加密功能（第四个参数设为false）
	如需加密功能请直接调用 EncodeMsgWithEncrypt
*/
func EncodeMsg(msg msgdef.IMsg, buf []byte, forceNoCompress bool) ([]byte, error) {
	return EncodeMsgWithEncrypt(msg, buf, forceNoCompress, false)
}

// EncodeMsgWithEncrypt 对消息进行编码处理，支持压缩和加密选项
// 参数:
//
//	msg: 需要编码的消息对象，必须实现msgdef.IMsg接口
//	buf: 用于存储编码结果的缓冲区，允许传入现有缓冲区减少内存分配
//	forceNoCompress: 强制禁用压缩的标志，true表示不进行压缩
//	encryptEnabled: 启用加密的标志，true表示需要对消息内容进行加密
//
// 返回值:
//
//	[]byte: 编码后的字节数据（包含可能的压缩和加密处理）
//	error: 处理过程中遇到的错误信息
func EncodeMsgWithEncrypt(msg msgdef.IMsg, buf []byte, forceNoCompress bool, encryptEnabled bool) ([]byte, error) {

	// 基础参数校验：确保消息对象有效
	if msg == nil {
		return nil, errors.New("消息错误，消息不能为nil")
	}

	// 通过消息类型管理器获取消息ID
	msgID, err := msgdef.GetTypeMgr().GetMsgIDByName(msg.Name())
	if err != nil {
		return nil, err
	}

	var msgbuf []byte

	// 根据消息大小决定压缩策略：
	// 当消息大小达到压缩阈值且未强制禁用压缩时，使用snappy压缩算法
	size := msg.Size()
	if size >= minCompressSize && !forceNoCompress {
		msgbuf, err = _snappyCompressCmd(msgID, msg, size, buf)
		if err != nil {
			return nil, err
		}
	} else {
		// 执行非压缩编码流程（当不满足压缩条件时）
		msgbuf, err = _noCompressCmd(msgID, msg, size, buf)
		if err != nil {
			return nil, err
		}
	}

	// 加密处理模块：
	// 当启用加密时，对消息体进行加密并设置消息头标志位
	if encryptEnabled {
		data := msgbuf[MsgHeadSize:]
		EncryptData(data)
		msgbuf[3] = msgbuf[3] | SessFlag_Encrypt // 设置加密标志位（第2位）
	}

	return msgbuf, err
}

// 对二进制数据进行压缩
func SnappyCompressBytes(olddata, buf []byte) ([]byte, error) {
	cmd := GetMsgID(olddata[MsgHeadSize-MsgIDSize:])
	data := olddata[MsgHeadSize:]
	maxLen := snappy.MaxEncodedLen(len(data))
	if maxLen+MsgHeadSize > len(buf) {
		return nil, fmt.Errorf("SnappyCompressBytes message size too large msgSize %d limitSize %d ",
			maxLen+MsgHeadSize, len(buf))
	}
	p := buf[:maxLen+MsgHeadSize]
	mbuff := snappy.Encode(p[MsgHeadSize:], data)
	cmdsize := len(mbuff) + MsgIDSize
	p[0] = byte(cmdsize)
	p[1] = byte(cmdsize >> 8)
	p[2] = byte(cmdsize >> 16)
	p[3] = p[3] | SessFlag_Compress // 设置压缩标志位（第1位）
	binary.LittleEndian.PutUint16(p[4:], cmd)
	return p[:len(mbuff)+MsgHeadSize], nil
}

// 带压缩算法的编码
func _snappyCompressCmd(cmd uint16, msg msgdef.IMsg, msgSize int, buf []byte) ([]byte, error) {

	msgdata := make([]byte, msgSize)

	n, err := msg.MarshalTo(msgdata)
	if err != nil {
		logger.Error("[协议] 编码错误 ", err)
		return nil, err
	}
	data := msgdata[:n]
	maxLen := snappy.MaxEncodedLen(len(data))

	if maxLen+MsgHeadSize > len(buf) {
		return nil, fmt.Errorf("message size too large msgSize %d limitSize %d cmd %d msgName %s",
			msgSize+MsgHeadSize, len(buf), cmd, msg)
	}

	p := buf[:maxLen+MsgHeadSize]

	mbuff := snappy.Encode(p[MsgHeadSize:], data)
	cmdsize := len(mbuff) + MsgIDSize
	p[0] = byte(cmdsize)
	p[1] = byte(cmdsize >> 8)
	p[2] = byte(cmdsize >> 16)
	p[3] = p[3] | SessFlag_Compress // 设置压缩标志位（第1位）
	binary.LittleEndian.PutUint16(p[4:], cmd)
	return p[:len(mbuff)+MsgHeadSize], nil
}

// 不带压缩算法的编码
func _noCompressCmd(cmd uint16, msg msgdef.IMsg, msgSize int, buf []byte) ([]byte, error) {

	if msgSize+MsgHeadSize > len(buf) {
		return nil, fmt.Errorf("message size too large msgSize %d limitSize %d cmd %d msgName %s",
			msgSize+MsgHeadSize, len(buf), cmd, msg)
	}

	data := buf
	n, err := msg.MarshalTo(data[MsgHeadSize:])
	if err != nil {
		logger.Error("[协议] 编码错误 ", err)
		return nil, err
	}
	cmdsize := n + MsgIDSize
	data[0] = byte(cmdsize)
	data[1] = byte(cmdsize >> 8)
	data[2] = byte(cmdsize >> 16)
	data[3] = 0
	binary.LittleEndian.PutUint16(data[4:], cmd)
	return data[:n+MsgHeadSize], nil
}

// 发给客户端的二进制,用于调试
func DecodeBytes(data []byte) (msg msgdef.IMsg, err error) {
	flag, buf := data[3], data[4:]
	if len(buf) < MsgIDSize {
		return nil, Err_Data_Len.NewErr(nil, len(buf))
	}
	_, msgContent, err := msgdef.GetTypeMgr().GetMsgInfo(GetMsgID(buf))
	if err != nil {
		return nil, err
	}

	msgBody := buf[MsgIDSize:]

	encryptFlag := flag & SessFlag_Encrypt
	if encryptFlag > 0 {
		msgBody = DecryptDataByClient(msgBody)
	}

	compressFlag := flag & SessFlag_Compress
	if compressFlag == 0 {
		if err = msgContent.Unmarshal(msgBody); err != nil {
			return nil, err
		}
	} else {
		msgBuf := pool.Get(MaxMsgBuffer) //make([]byte, MaxMsgBuffer)
		defer pool.Put(msgBuf)
		unCompressBuf, err := snappy.Decode(msgBuf, msgBody)
		if err != nil {
			return nil, err
		}
		if err = msgContent.Unmarshal(unCompressBuf); err != nil {
			return nil, err
		}
	}

	return msgContent, nil
}
